package com.hanlp.tokenizer.core;

import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.py.Pinyin;
import com.hankcs.hanlp.dictionary.py.PinyinDictionary;
import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import com.hanlp.tokenizer.NatureAttribute;
import com.hanlp.tokenizer.dictionary.SynonymDictionary;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
import org.elasticsearch.SpecialPermission;

import java.io.BufferedReader;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.LinkedTransferQueue;
import java.util.stream.Collectors;

/**
 * HanLP-based Lucene {@link org.apache.lucene.analysis.Tokenizer} supporting
 * NLP-style and index-mode segmentation, with optional synonym and pinyin
 * token expansion.
 *
 * <PRE>
 * <BR>	Modification history
 * <BR>-----------------------------------------------
 * <BR>	Date			Author			Description
 * </PRE>
 *
 * @author youshipeng
 * @since 1.0
 * @version 1.0
 */
public class MyTokenizer extends Tokenizer {

    private final CharTermAttribute charTermAttribute = addAttribute(CharTermAttribute.class);
    private final OffsetAttribute offsetAttribute = addAttribute(OffsetAttribute.class);
    private final PositionIncrementAttribute positionIncrementAttribute = addAttribute(PositionIncrementAttribute.class);
    private final NatureAttribute natureAttribute = addAttribute(NatureAttribute.class);
    private final TypeAttribute typeAttribute = addAttribute(TypeAttribute.class);

    /** Part-of-speech tag applied to single-character tokens produced by {@link #decomposition(Term)}. */
    private static Nature auxiliary;
    /** Segment configured for NLP-style analysis (POS tagging, NER, offsets). */
    private static Segment NLPSegment;
    /** Segment configured for index mode (finer-grained tokens, offsets). */
    private static Segment indexSegment;

    static {
        // HanLP loads dictionaries from disk; under the Elasticsearch security
        // manager this must run in an explicitly privileged section.
        SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkPermission(new SpecialPermission());
        }
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            auxiliary = Nature.create("auxiliary");
            return null;
        });
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            NLPSegment = HanLP.newSegment()
                    .enablePartOfSpeechTagging(true)        // POS tagging
                    .enableOffset(true)                     // compute character offsets
                    .enableNameRecognize(true)              // Chinese person-name recognition
//                    .enableJapaneseNameRecognize(true)    // Japanese person-name recognition
                    .enableNumberQuantifierRecognize(true)  // number/quantifier recognition
                    .enableOrganizationRecognize(true);     // organization-name recognition
//                    .enableTranslatedNameRecognize(true); // transliterated person-name recognition
            indexSegment = HanLP.newSegment()
                    .enableIndexMode(1)
                    .enablePartOfSpeechTagging(true)        // POS tagging
                    .enableOffset(true);                    // compute character offsets
            // Run one segmentation now so dictionary loading/caching happens
            // inside the privileged section rather than at first query time.
            System.out.println(NLPSegment.seg("HanLP中文分词工具包！"));
            System.out.println(indexSegment.seg("HanLP中文分词工具包！"));
            return null;
        });
    }

    /** Analysis mode (standard / index / synonym / pinyin variants) for this tokenizer instance. */
    private final SegmentationType segmentationType;
    private BufferedReader reader = null;
    /** Segmented terms waiting to be emitted. */
    private final Queue<MyTerm> terms = new LinkedTransferQueue<>();
    /** Token strings (head word plus synonym/pinyin expansions) pending emission. */
    private final Queue<String> tokens = new LinkedTransferQueue<>();

    /** The term whose expansion tokens are currently being drained from {@link #tokens}. */
    private Term current;
    /** Total number of characters read from the input; reported by {@link #end()}. */
    private int finalOffset = 0;
    /** True once the underlying reader has been fully consumed and segmented. */
    private boolean exhausted = false;

    public MyTokenizer(SegmentationType segmentationType) {
        this.segmentationType = segmentationType;
    }

    @Override
    public boolean incrementToken() throws IOException {
        // Lucene contract: implementations must clear all attributes before
        // producing a new token (TokenStream#incrementToken javadoc).
        clearAttributes();
        String token = getToken();
        if (token != null) {
            charTermAttribute.setEmpty().append(token);
            return true;
        }
        return false;
    }

    @Override
    public void reset() throws IOException {
        // Tokenizers are reused across documents. Drop all per-document state
        // so the next document is read from the freshly assigned `input`.
        super.reset();
        reader = null;
        terms.clear();
        tokens.clear();
        current = null;
        finalOffset = 0;
        exhausted = false;
    }

    @Override
    public void end() throws IOException {
        super.end();
        // Report the final offset so consumers (e.g. multi-valued fields)
        // account for trailing input after the last token.
        int offset = correctOffset(finalOffset);
        offsetAttribute.setOffset(offset, offset);
    }

    /**
     * Segments {@code text} with the NLP segment and, in index mode, merges in
     * the finer-grained index segmentation, de-duplicating via a set.
     *
     * @return terms sorted by start offset; longer words first at equal offsets
     */
    private List<MyTerm> seg(String text) {
        Set<MyTerm> merged = segAndConvert(NLPSegment, text);
        if (segmentationType.isIndexSegmentationType()) {
            merged.addAll(segAndConvert(indexSegment, text));
        }

        List<MyTerm> finalResults = new ArrayList<>(merged);
        Comparator<MyTerm> offsetComparator = Comparator.comparingInt(t -> t.offset);
        Comparator<MyTerm> lengthComparator = Comparator.comparing(t -> t.word.length(), Comparator.reverseOrder());
        finalResults.sort(offsetComparator.thenComparing(lengthComparator));
        return finalResults;
    }

    /** Runs {@code segment} over {@code text} and wraps the results as de-duplicated {@link MyTerm}s. */
    private Set<MyTerm> segAndConvert(Segment segment, String text) {
        return segment.seg(text).stream().map(MyTerm::new).collect(Collectors.toSet());
    }

    /**
     * Returns the next segmented term, reading and segmenting the entire input
     * on first use; {@code null} once the input is exhausted.
     */
    private MyTerm getTerm() throws IOException {
        MyTerm term = terms.poll();
        if (term == null && !exhausted) {
            if (reader == null) {
                reader = new BufferedReader(input);
            }
            int length;
            StringBuilder sb = new StringBuilder();
            char[] buffer = new char[500];
            while (-1 != (length = reader.read(buffer, 0, buffer.length))) {
                sb.append(buffer, 0, length);
            }
            // Remember total length for end(), and avoid re-reading the
            // drained reader (and re-segmenting "") on every later call.
            finalOffset = sb.length();
            exhausted = true;
            terms.addAll(seg(sb.toString()));
            term = terms.poll();
        }
        return term;
    }

    /** Populates offset, position-increment, and part-of-speech attributes from {@code term}. */
    private void applyAttributes(Term term, int positionIncrement) {
        offsetAttribute.setOffset(correctOffset(term.offset),
                correctOffset(term.offset + term.word.length()));
        positionIncrementAttribute.setPositionIncrement(positionIncrement);
        natureAttribute.setNature(term.nature);
        typeAttribute.setType(term.nature.toString());
    }

    /**
     * Returns the next token string with its attributes populated, or
     * {@code null} when no tokens remain.
     */
    private String getToken() throws IOException {
        String token = tokens.poll();
        if (token != null) {
            // An expansion (synonym / pinyin form) of the current head term:
            // same surface offsets, position increment 0 so it is stacked on
            // the same position as the head token.
            applyAttributes(current, 0);
            return token;
        }
        // Loop so that a term whose expansion is empty (e.g. no synonym
        // dictionary entries) does not prematurely terminate the stream.
        while (token == null) {
            MyTerm term = getTerm();
            if (term == null) {
                return null;
            }
            current = term;
            applyAttributes(term, 1);

            if (segmentationType == SegmentationType.synonym) {
                // NOTE(review): only dictionary entries are emitted here —
                // confirm the synonym dictionary includes the original word.
                SynonymDictionary.get(term.word).forEach(tokens::offer);
            } else if (segmentationType == SegmentationType.pinyin) {
                tokens.offer(term.word);
                String acronym = wordToPinyin(term.word, true);
                String whole = wordToPinyin(term.word, false);
                if (!acronym.isEmpty()) {
                    tokens.offer(acronym);
                }
                if (!whole.isEmpty()) {
                    tokens.offer(whole);
                }
            } else if (segmentationType == SegmentationType.pinyin_polyphone) {
                tokens.offer(term.word);
                wordToAllPinyin(term.word, true).forEach(tokens::offer);
                wordToAllPinyin(term.word, false).forEach(tokens::offer);
            } else {
                tokens.offer(term.word);
            }

            token = tokens.poll();
        }
        return token;
    }

    /**
     * Converts {@code text} to pinyin, skipping characters without a reading.
     *
     * @param isAcronym when {@code true}, keep only the first letter of each
     *                  syllable; otherwise the full tone-less syllable
     * @return the concatenated pinyin, possibly empty
     */
    public static String wordToPinyin(String text, boolean isAcronym) {
        List<Pinyin> pinyinList = PinyinDictionary.convertToPinyin(text, true);
        StringBuilder sb = new StringBuilder();
        for (Pinyin pinyin : pinyinList) {
            if (pinyin != Pinyin.none5) { // none5 is HanLP's "no pinyin" placeholder
                if (isAcronym) {
                    sb.append(pinyin.getFirstChar());
                } else {
                    sb.append(pinyin.getPinyinWithoutTone());
                }
            }
        }
        return sb.toString();
    }

    /**
     * Returns every pinyin spelling of {@code word}, taking all polyphone
     * readings of each character (cartesian product across characters).
     * Characters without a reading are skipped.
     */
    private static Set<String> wordToAllPinyin(String word, boolean isAcronym) {
        if (word == null || word.isEmpty()) {
            return Collections.emptySet();
        }
        Set<String> pinyins = new HashSet<>();
        for (char c : word.toCharArray()) {
            Pinyin[] charPinyins = PinyinDictionary.get(String.valueOf(c));
            if (charPinyins == null || charPinyins.length == 0) {
                continue;
            }
            if (pinyins.isEmpty()) {
                // First character with a reading seeds the candidate set.
                pinyins = Arrays.stream(charPinyins)
                        .map(p -> isAcronym
                                ? String.valueOf(p.getFirstChar())
                                : p.getPinyinWithoutTone())
                        .collect(Collectors.toSet());
                continue;
            }
            // Extend every existing candidate with every reading of this char.
            Set<String> extended = new HashSet<>();
            for (Pinyin charPinyin : charPinyins) {
                String suffix = isAcronym
                        ? String.valueOf(charPinyin.getFirstChar())
                        : charPinyin.getPinyinWithoutTone();
                for (String prefix : pinyins) {
                    extended.add(prefix + suffix);
                }
            }
            pinyins = extended;
        }
        return pinyins;
    }

    /**
     * Enqueues each character of {@code term} as a separate token, except for
     * single-character terms and auxiliary words.
     */
    private void decomposition(Term term) {
        if (term.length() == 1 || term.nature == auxiliary) {
            return;
        }
        for (int i = 0; i < term.length(); i++) {
            tokens.offer(term.word.substring(i, i + 1));
        }
    }

    // NOTE(review): this method is currently unreachable — incrementToken()
    // only calls getToken(). Kept for the planned index-decomposition mode;
    // confirm whether it should be wired in or removed.
    private String getIndexToken() throws IOException {
        String token = tokens.poll();
        if (token == null) {
            Term term = getTerm();
            if (term != null) {
                current = term;
                applyAttributes(term, 1);
                tokens.offer(term.word);
                decomposition(term);
                token = tokens.poll();
            }
        } else {
            // Single-character token from decomposition(): derive its offset
            // from how many characters of `current` are still queued.
            offsetAttribute.setOffset(
                    current.offset + current.length() - tokens.size() - 1,
                    current.offset + current.length() - tokens.size());
            positionIncrementAttribute.setPositionIncrement(1);
            natureAttribute.setNature(auxiliary);
            typeAttribute.setType(auxiliary.toString());
        }
        return token;
    }
}